{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6.6 动量法\n",
    "小批量梯度下降法从训练样本的使用角度改善了随机梯度下降法的效果。但在追求极致的优化算法道路上，科学家们远未停止前进的步伐。他们不停的找问题，想新思路，进一步改善算法性能。前面三种算法都使用了梯度计算，能不能从这里入手改变呢。真有人就这么干了，而且效果不错。这就是本节要介绍的动量法了。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.6.1 基本思想\n",
    "\n",
    "动量法是一种在深度学习中使用的优化方法，核心思想是让模型的更新更加平稳，从而使学习更加顺畅。具体来说，动量法通过将当前的梯度信息与上一步的梯度信息进行加权平均来减少梯度的震荡。这可以通过使用动量因子 $\\beta$ 来实现。给定一个参数 $\\theta$ 和当前的梯度 $\\nabla_\\theta J(\\theta)$，我们可以使用动量法来更新 $\\theta$，公式如下：\n",
    "\n",
    "$$v_t = \\beta v_{t-1} + (1-\\beta)\\nabla_\\theta J(\\theta)$$\n",
    "$$\\theta = \\theta - \\alpha v_t$$\n",
    "\n",
    "其中，$\\alpha$ 是学习率，$v_t$ 是梯度的动量，$\\beta$ 是动量因子。在每一步中，我们将当前的梯度与上一步的梯度加权平均起来，并使用这个平均值来更新参数。当 $\\beta$ 很大时，动量较大，这意味着梯度更多地“记住”之前的信息，因此梯度的波动会减少。"
   ]
  },
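  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the update rule concrete, here is a minimal sketch in plain Python. The toy objective $J(x, y) = \\frac{1}{2}(x^2 + 25y^2)$, the names `grad`, `loss` and `momentum_descent`, and the values of `alpha` and `beta` are all assumptions chosen for illustration; the velocity is assumed to start at zero.\n",
    "\n",
    "```python\n",
    "# Toy objective (an assumption for illustration): J = 0.5 * (1 * x**2 + 25 * y**2)\n",
    "CURVATURE = (1.0, 25.0)\n",
    "\n",
    "def grad(theta):\n",
    "    # Gradient of the toy quadratic at theta\n",
    "    return [c * t for c, t in zip(CURVATURE, theta)]\n",
    "\n",
    "def loss(theta):\n",
    "    return 0.5 * sum(c * t * t for c, t in zip(CURVATURE, theta))\n",
    "\n",
    "def momentum_descent(theta0, alpha, beta, n_steps):\n",
    "    # v_t = beta * v_{t-1} + (1 - beta) * grad;  theta = theta - alpha * v_t\n",
    "    theta = list(theta0)\n",
    "    v = [0.0] * len(theta)  # velocity starts at zero (assumption)\n",
    "    for _ in range(n_steps):\n",
    "        g = grad(theta)\n",
    "        v = [beta * vi + (1 - beta) * gi for vi, gi in zip(v, g)]\n",
    "        theta = [ti - alpha * vi for ti, vi in zip(theta, v)]\n",
    "    return theta\n",
    "\n",
    "start = [5.0, 1.0]\n",
    "plain = momentum_descent(start, alpha=0.07, beta=0.0, n_steps=50)   # beta = 0: plain gradient descent\n",
    "smooth = momentum_descent(start, alpha=0.07, beta=0.5, n_steps=50)  # with momentum\n",
    "print('plain GD loss:', loss(plain))\n",
    "print('momentum loss:', loss(smooth))\n",
    "```\n",
    "\n",
    "Setting `beta=0` removes the averaging and recovers plain gradient descent, which makes it easy to compare the two rules from the same starting point."
   ]
  },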
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.6.2 优缺点\n",
    "\n",
    "先说优点：\n",
    "- 动量法可以帮助算法在局部最优解附近更快地收敛。这是因为动量法可以帮助算法“跳过”局部最小值，它可以保持一定的加速度，使算法能够跨越那些低点。\n",
    "- 可以减少学习率的需求。因为动量法可以帮助算法跳过局部最小值，所以我们不需要设置过大的学习率来更快地收敛。这可以减少学习率过大导致的振荡的风险。\n",
    "- 可以帮助算法更快地从平凡的初始权重开始收敛。这是因为动量法可以帮助算法跳过局部最小值，因此算法可以更快地从平凡的初始权重开始收敛。\n",
    "\n",
    "缺点方面：\n",
    "- 动量法可能会使算法“超调”。如果我们设置的动量因子 $\\beta$ 过大，那么动量就会变得过大，这可能会导致算法“超调”。\n",
    "\n",
    "### 6.6.3 发展历史\n",
    "\n",
    "动量法是由深度学习先驱 Geoffrey Hinton 在 1986 年提出的。在当时，他提出了一种名为“快速动量法”的优化方法，用于解决深度神经网络训练过程中的梯度消失问题。在 1990 年代，动量法得到了广泛应用，并在许多研究中得到了证明。在 2000 年代，动量法在深度学习中得到了广泛应用，并成为了许多深度学习框架的默认优化方法。近年来，动量法也在不断发展，人们提出了许多改进版本，比如 Nesterov 动量法，AdaGrad 动量法等。这些改进版本在某些情况下表现得更好，因此也得到了广泛使用。\n",
    "\n",
    "### 6.6.4 代码示例\n",
    "\n",
    "在 PyTorch 中，你可以使用 torch.optim 库中的 SGD 类来实现动量法。下面是一个简单的例子，展示了如何使用动量法来优化一个简单的神经网络,我们同时对比了不使用动量法的情况。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEWCAYAAABhffzLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAAsTAAALEwEAmpwYAAAo2klEQVR4nO3deZwU5bX/8c8BhmVgQBxQWQUVUFlEHRNwV1xQE5d4f7khGpdfEuNNYowaRbO43Jv4uma7iTHXXzRqNFHDdUH9iQsoKhpRBEQFUQM46sCoMOCAyM65fzw1TDPOQM9MV1dP1/f9etWrq6urq07heOrpp546Ze6OiIikR7ukAxARkfxS4hcRSRklfhGRlFHiFxFJGSV+EZGUUeIXEUkZJX5JJTN73MzOzfW6zYzhaDOryvV2RXamQ9IBiGTLzD7NeFsKbAC2RO+/4+53Z7stdz8pjnVF2gIlfmkz3L1b3byZVQLfcvenGq5nZh3cfXM+YxNpS9TVI21eXZeJmU00sw+BO8ysp5k9ambLzWxVNN8/4zvPmtm3ovnzzOwFM/t1tO67ZnZSC9cdbGYzzGyNmT1lZn80s79leRz7Rfv6xMwWmNmpGZ+dbGZvRttdamY/ipb3io7tEzNbaWbPm5n+v5Yd0h+IFIs9gF2BPYELCH/bd0TvBwLrgJt28P0vAm8DvYBfAreZmbVg3XuAWUA5cC3wjWyCN7MS4P8DU4HdgIuAu81sWLTKbYTurDJgBDA9Wn4ZUAX0BnYHfgyoDovskBK/FIutwDXuvsHd17l7jbs/4O6fufsa4BfAUTv4/nvufqu7bwHuBPoQEmnW65rZQOAQ4Gp33+juLwCPZBn/GKAb8J/Rd6cDjwITos83AfubWXd3X+XuczOW9wH2dPdN7v68qwCX7IQSvxSL5e6+vu6NmZWa2Z/M7D0zWw3MAHYxs/ZNfP/Duhl3/yya7dbMdfsCKzOWAXyQZfx9gQ/cfWvGsveAftH8mcDJwHtm9pyZjY2W/wpYBEw1syVmdmWW+5MUU+KXYtGwlXsZMAz4ort3B46MljfVfZML1cCuZlaasWxAlt9dBgxo0D8/EFgK4O6vuPtphG6gh4D/iZavcffL3H0v4FTgUjMb17rDkGKnxC/FqozQr/+Jme0KXBP3Dt39PWA2cK2ZdYxa5V/O8usvA58BV5hZiZkdHX3379G2zjKzHu6+CVhN6NrCzL5kZvtE1xhqCcNbtza6B5GIEr8Uq98BXYAVwEvAE3na71nAWKAG+DkwiXC/wQ65+0ZCoj+JEPN/A+e4+1vRKt8AKqNuqwuj/QAMAZ4CPgVmAv/t7s/k7GikKJmuA4nEx8wmAW+5e+y/OESypRa/SA6Z2SFmtreZtTOz8cBphD55kYIRW+I3s9vN7GMzm5+x7Fdm9paZvW5mk81sl7j2L5KQPYBnCV0vNwL/5u6vJhqRSAOxdfWY2ZGEP/673H1EtOwEYLq7bzazGwDcfWIsAYiISKNia/G7+wxgZYNlUzNqqLwE9P/cF0VEJFZJFmn7v4QRD40yswsIt97TtWvXg/fdd99W73DlSnj3Xdh/f+jSpdWbExEpaHPmzFnh7r0bLo91VI+ZDQIerevqyVj+E6AC+Eo2t5dXVFT47NmzWx3PnDlQUQEPPABf+UqrNyciUtDMbI67VzRcnvcWv5mdB3wJGJfvmiLDonJXb7214/VERIpZXhN/NLztCuCoBvVM8qJbN+jfX4lfRNItzuGc9xLuJBwW1Ur/JqEsbhkwzczmmdn/i2v/Tdl3XyV+EUm32Fr87j6hkcW3xbW/bO27L9x1F7hDk9XWRaTFNm3aRFVVFevXr9/5ypITnTt3pn///pSUlGS1fuoevThsGKxeDR9+CH36JB2NSPGpqqqirKyMQYMG0fSzbCRX3J2amhqqqqoYPHhwVt9JXcmGulGh6u4Ricf69espLy9X0s8TM6O8vLxZv7BSm/gXLkw2
DpFipqSfX839905d4u/XD7p3hwULko5ERCQZqUv8ZjBiBMyfv/N1RURyZd68eTz22GNJhwGkMPFDfeLXowhEJF+U+BM2YkSo2/PRR0lHIiJxqKysZN999+W8885j6NChnHXWWTz11FMcdthhDBkyhFmzZrFy5UpOP/10Ro0axZgxY3j99dcBuPbaazn33HM54ogj2HPPPXnwwQe54oorGDlyJOPHj2fTpk0AzJkzh6OOOoqDDz6YE088kerqagCOPvpoJk6cyBe+8AWGDh3K888/z8aNG7n66quZNGkSo0ePZtKkSVx77bX8+te/3hbziBEjqKyszCr21krdcE6A4cPD6/z5sMceycYiUtR++EOYNy+32xw9Gn73u52utmjRIu677z5uv/12DjnkEO655x5eeOEFHnnkEa6//noGDBjAgQceyEMPPcT06dM555xzmBfFunjxYp555hnefPNNxo4dywMPPMAvf/lLzjjjDKZMmcIpp5zCRRddxMMPP0zv3r2ZNGkSP/nJT7j99tsB2Lx5M7NmzeKxxx7juuuu46mnnuLf//3fmT17NjfddBMQTjAtjf2hhx5q1T9hKhP/iKhk3Pz5cNxxycYiIvEYPHgwI0eOBGD48OGMGzcOM2PkyJFUVlby3nvv8cADDwBw7LHHUlNTw+rVqwE46aSTKCkpYeTIkWzZsoXx48cDbPvu22+/zfz58zn++OMB2LJlC30ybgz6SlQF8uCDD6aysjLnsbdWKhP/brtB7966wCsSuyxa5nHp1KnTtvl27dpte9+uXTs2b968w7tcM9ctKSnZNlyy7rvuzvDhw5k5c+YOv9++fXs2b97c6DodOnRg69at295njsPfWeytlco+ftDIHpG0O+KII7j77rsBePbZZ+nVqxfdu3fP6rvDhg1j+fLl2xL/pk2bWLCTMeJlZWWsWbNm2/tBgwYxd+5cAObOncu7777bksNokVQn/gULNLJHJK2uvfZa5syZw6hRo7jyyiu58847s/5ux44duf/++5k4cSIHHHAAo0eP5sUXX9zhd4455hjefPPNbRd3zzzzTFauXMnw4cO56aabGDp0aGsPKWuxPoglV3L1IJZMf/oTXHghVFbCnnvmdNMiqbZw4UL222+/pMNIncb+3Zt6EEtqW/zRdRNeey3ZOERE8i3Vid9MiV9E0ie1ib+sDPbZR4lfRNIntYkf4IADcn9viYhIoUt14h89GhYvDg9mERFJi9QnfoCoRIeISCoo8aPuHhGpN2jQIFasWJF0GLFKdeLv2xd69VLiF5F0SWWtnjpmodWvxC9SXNauXctXv/pVqqqq2LJlCz/72c8oKyvj0ksvpWvXrhx22GEsWbKERx99lJqaGiZMmMDSpUsZO3YsbeGm1tZKdeKHkPj/8AfYvBk6pP5fQyS3kqrK/MQTT9C3b1+mTJkCQG1tLSNGjGDGjBkMHjyYCRMmbFv3uuuu4/DDD+fqq69mypQp3HbbbbkNuACluqsHwpDODRvg7beTjkREcmXkyJFMmzaNiRMn8vzzz/Puu++y1157MXjwYIDtEv+MGTM4++yzATjllFPo2bNnIjHnU+rbuJkXeOse0CIiuZFUVeahQ4cyd+5cHnvsMX76058ybty4ZAIpUKlv8Q8bBp06qZ9fpJgsW7aM0tJSzj77bC6//HL+8Y9/sGTJkm0PMZk0adK2dY888kjuueceAB5//HFWrVqVRMh5lfoWf0lJKNGsxC9SPN544w0uv/zybQ9Sufnmm6murmb8+PF07dqVQw45ZNu611xzDRMmTGD48OEceuihDBw4MMHI8yO2xG9mtwNfAj529xHRsl2BScAgoBL4qrsnfnodPRoefjjU5o8etCMibdiJJ57IiSeeuN2yTz/9lLfeegt353vf+x4VFaFacXl5OVOnTk0izMTE2dXzF2B8g2VXAk+7+xDg6eh94kaPhhUrYNmypCMRkbjceuutjB49muHDh1NbW8t3vvOdpENK
TGwtfnefYWaDGiw+DTg6mr8TeBaYGFcM2aq7wPvqq9CvX6KhiEhMLrnkEi655JKkwygI+b64u7u7V0fzHwK753n/jRo9OnTxzJmTdCQixSENN0EVkub+eyc2qsdDpE1Ga2YXmNlsM5u9fPnyWGPp1g322w9y/HRHkVTq3LkzNTU1Sv554u7U1NTQuXPnrL+T71E9H5lZH3evNrM+wMdNrejutwC3QHjmbtyBVVTA1Km6wCvSWv3796eqqoq4G2xSr3PnzvTv3z/r9fOd+B8BzgX+M3p9OM/7b1JFBdx1V7jAq35+kZYrKSnZdoesFKbYunrM7F5gJjDMzKrM7JuEhH+8mf0TOC56XxCikV3q7hGRohfnqJ4JTXyU33unt2yB9u13utoBB4TVZs+G007LQ1wiIgkp7pINV1yRdQGe0tKwqlr8IlLsijvxl5eHsptZ1t6oqAiJX4MRRKSYFXfiP+ig8Dp3blarV1SEO3jffz/GmEREElbcif/AA8NrMxI/qLtHRIpbcSf+Xr1Cd8+SJVmtPmpUqNapxC8ixay4Ez+EQflLl2a1aqdOMHKkEr+IFDcl/gZ0gVdEip0SfwMVFfDJJ7B4cXwhiYgkKR2J/+OPYdOmrFavezDPK6/EGJOISILSkfjdobp65+sSHsNYWgovvRRzXCIiCSn+xL/bbuE1y0qBHTqEVv/MmTHGJCKSoOJP/OXl4bWmJuuvjBkTnsa1bl1MMYmIJEiJvxFjx8LmzVnf9yUi0qYUf+Lv1Su8rliR9VfGjg2v6u4RkWJU/Im/Z8/w2owW/267wV57KfGLSHEq/sTfoUNI/s1I/BD6+WfO1I1cIlJ8ij/xQ+jnb0ZXD4Tunupq+OCDmGISEUlIOhJ/r17NbvGrn19EilU6Ev8uu4Q6DM0wahR06aLELyLFJx2Jv0cPqK1t1ldKSkLdHiV+ESk26Uj83bvD6tXN/trYseFGrvXrY4hJRCQh6Uj8LWjxQ0j8mzbpRi4RKS7pSfzr1mVdobNO3QXef/wjhphERBKSjsTfvXt4bWarf/fdYehQmDEjhphERBKSjsTfo0d4bUE//1FHwfPPw5YtOY5JRCQh6Ur8LejnP/LI8LU33shxTCIiCUlH4m9hVw+EFj+ou0dEikc6En8runoGDIBBg+C553IbkohIUhJJ/GZ2iZktMLP5ZnavmXWOdYet6OqB0OqfMUMF20SkOOQ98ZtZP+AHQIW7jwDaA1+Ldaet6OqB0M+/YgUsXJjDmEREEpJUV08HoIuZdQBKgWWx7q2VLf4jjwyv6ucXkWKQ98Tv7kuBXwPvA9VArbtPbbiemV1gZrPNbPbyLB+U3qROncLUgj5+gL33hr59lfhFpDgk0dXTEzgNGAz0Bbqa2dkN13P3W9y9wt0revfu3fodd+/e4ha/WWj1P/ec+vlFpO1LoqvnOOBdd1/u7puAB4FDY99rC+v11DnqKFi2DJYsyWFMIiIJSCLxvw+MMbNSMzNgHBD/ZdMePVrc1QP1/fzPPpubcEREkpJEH//LwP3AXOCNKIZbYt9xK7p6APbbL9TumT49hzGJiCSgQxI7dfdrgGvyutMePWDRohZ/3QzGjYOnnw79/GY5jE1EJI/ScecutLqPH+C44+Cjj2DBghzFJCKSgPQk/hY+hSvTuHHh9amnchCPiEhC0pf4WzEec+DAUJ9fiV9E2rL0JP6yspD0161r1WaOOy6M7Gnmw7xERApGehJ/t27hdc2aVm1m3DhYuxZefjkHMYmIJCA9ib+sLLx++mmrNnPMMWFEj7p7RKStSk/iz1GLv2dPqKiAadNyEJOISALSk/hz1OIHOOGE0NWzalWrNyUiknfpSfw5avEDnHxyePj61M/VFBURKXzpSfw5bPF/8Yuw667w2GOt3pSISN6lJ/HnsMXfvj2MHw+PPw5bt7Z6cyIieZWexJ/DFj+E7p7ly2H27JxsTkQkb9KX+HPQ4ofQ4jdTd4+ItD3pSfwdO0JJSc5a
/OXlMGYMTJmSk82JiORNehI/hFZ/jlr8ELp7Zs8OFTtFRNqKdCX+bt1y1uIHOOWU8PrEEznbpIhI7NKV+HPc4h89Gvr0UXePiLQt6Ur8OW7xm8FJJ8GTT8LGjTnbrIhIrNKV+HPc4gc4/fRQ5v+ZZ3K6WRGR2KQr8ee4xQ9w/PHQtStMnpzTzYqIxCZdiT+GFn/nzmF0z0MPhfo9IiKFLl2JP4YWP8AZZ4QhnS+9lPNNi4jkXFaJ38y6mlm7aH6omZ1qZiXxhhaDGFr8EFr8JSXw4IM537SISM5l2+KfAXQ2s37AVOAbwF/iCio23brBhg05f2Bujx7hWbyTJ7fqWe4iInmRbeI3d/8M+Arw3+7+f4Dh8YUVkxwXast0xhnw7rvw2ms537SISE5lnfjNbCxwFlB3u1L7eEKKUQ5LMzd02mnQrh3cd1/ONy0iklPZJv4fAlcBk919gZntBbS9kesxtvh32w3GjYO//13dPSJS2LJK/O7+nLuf6u43RBd5V7j7D1q6UzPbxczuN7O3zGxh9GsifjG2+AEmTIAlS+CVV2LZvIhITmQ7quceM+tuZl2B+cCbZnZ5K/b7e+AJd98XOABY2IptZS/GFj+Efv6OHeHee2PZvIhITmTb1bO/u68GTgceBwYTRvY0m5n1AI4EbgNw943u/klLttVsMbf4d9kl1O6ZNEk3c4lI4co28ZdE4/ZPBx5x901AS3uyBwPLgTvM7FUz+3P0S2I7ZnaBmc02s9nLly9v4a4aiLnFD6G7p7oaZsyIbRciIq2SbeL/E1AJdAVmmNmewOoW7rMDcBBws7sfCKwFrmy4krvf4u4V7l7Ru3fvFu6qgZhb/ABf/nKo3aPuHhEpVNle3L3R3fu5+8kevAcc08J9VgFV7v5y9P5+wokgfnlo8ZeWhqGdDzygUs0iUpiyvbjbw8x+W9f1Yma/IbT+m83dPwQ+MLNh0aJxwJst2VazlZaG1xhb/BC6e1auhKlTY92NiEiLZNvVczuwBvhqNK0G7mjFfi8C7jaz14HRwPWt2Fb22rWLrVBbphNOgF694K67Yt2NiEiLdMhyvb3d/cyM99eZ2byW7tTd5wEVLf1+q3TrFnuLv2NHOOssuPlmqKmB8vJYdyci0izZtvjXmdnhdW/M7DBgXTwhxaysLPYWP8D554c+/nvuiX1XIiLNkm3ivxD4o5lVmlklcBPwndiiilNMpZkbOuAAOPBAuKM1HWIiIjHIdlTPa+5+ADAKGBUNwzw21sjikoc+/jrnnw+vvqqKnSJSWJr1BC53Xx3dwQtwaQzxxC9PLX6Ar389PKBFrX4RKSStefSi5SyKfMpji7+8PIzp/+tfYf36vOxSRGSnWpP422bx4Ty2+AEuvDCM6VedfhEpFDtM/Ga2xsxWNzKtAfrmKcbcymOLH+DYY2Ho0DC0U0SkEOww8bt7mbt3b2Qqc/ds7wEoLHXDOfP0tBSz0OqfORPmzcvLLkVEdqg1XT1tU7duIel/9lnednneedCli1r9IlIY0pf46wq15bGfv2dP+NrX4G9/g9ravO1WRKRR6Uv8daWZ89jPD/Dd74YfGbffntfdioh8TvoSfwItfoCKCjj8cPj972Hz5rzuWkRkO+lL/Am1+AF+9CN4771Qq19EJCnpS/wJtfghPJ1ryBD4zW/yNqhIRORz0pf4E2zxt2sHl1wCr7wCL7yQ992LiABpTPwJtvgBzj03lHL45S8T2b2ISAoTf4ItfghPf/zBD+DRR0PlThGRfEtv4k+oxQ8h8ffoAf/xH4mFICIplr7E37FjmBJq8QPssgtcfDFMngyvv55YGCKSUulL/JD3Cp2N+eEPQxhq9YtIvqUz8ee5QmdjevYMXT73368ndIlIfqUz8RdAix/gssvCCeDKK5OORETSJJ2JvwBa/BCS/o9/DE88AdOnJx2NiKRFOhN/gbT4Ab7/fRg4EK64ArZuTToaEUmDdCb+AmnxA3TuDD//OcyZA5MmJR2N
iKRBOhN/AbX4Ac46Cw48MLT6165NOhoRKXbpTfwF0uKHUMPnppugqiq0/kVE4pRY4jez9mb2qpk9mvedd+tWUC1+gEMPDXV8fvMbePvtpKMRkWKWZIv/YmBhInsuK4ONG8NUQG64IdTyuegilW0WkfgkkvjNrD9wCvDnJPafdKG2puy+e+jqmTYN/vrXpKMRkWKVVIv/d8AVQJMDGM3sAjObbWazly9fntu915VmLrDED+HZvIcfHmr5LFuWdDQiUozynvjN7EvAx+4+Z0frufst7l7h7hW9e/fObRAFUKGzKe3ahQeyr18P3/mOunxEJPeSaPEfBpxqZpXA34FjzexveY2ggFv8EB7P+ItfhJr9d92VdDQiUmzynvjd/Sp37+/ug4CvAdPd/ey8BlHALf46F18MRxwB3/sevPNO0tGISDFJ7zh+KNgWP0D79nD33dCpE3zta7BhQ9IRiUixSDTxu/uz7v6lvO+4DbT4AQYMgDvuCI9ovPzypKMRkWKhFn+BO/XU0O3zhz9oiKeI5EY6E38bafHX+dWv4Oij4dvfhlmzko5GRNq6dCb+0lIwaxMtfoCSErjvPujTB844Q+P7RaR10pn4zQqyXs+O9OoFDz8MtbVw0knhVUSkJdKZ+KHgSjNnY9QomDwZFi6E004LN3mJiDRXehN/jx5tstl8/PFw553w3HMwYULB1ZkTkTYgvYm/Z09YuTLpKFpkwoQwyuehh+CrX1XyF5HmSXfiX7Uq6Sha7PvfDw9vefhhOPNM3eAlItlLb+Lfddc2nfghlHO4+eZQ0+fUU9vcJQsRSUh6E38b7urJdOGFoZrn00+H2j5VVUlHJCKFLt2Jf/Vq2LIl6Uha7fzzYcoUWLIEvvhFmDcv6YhEpJClN/Hvumt4/eSTRMPIlRNPhBdeCPX8DztM5R1EpGnpTfw9e4bXNt7Pn2nUqFDSoaICzjkHLrhAY/1F5POU+Iugnz9Tnz6hv/+qq+DWW0PXz2uvJR2ViBSS9Cb+Xr3Ca66f51sAOnSA668P/f4ffgiHHBIe4r5pU9KRiUghSG/i33338PrRR9svX78+3Bn12Wd5DynXTj4ZFiwI4/x/9rPQ+n/ppaSjEpGkKfE3TPynnx5KYJ53Xr4jikWvXnDvvXD//eFQx44No4AaHraIpEd6E39paSjUlpkBX3sNnnwyzN93X6iGViTOPBPeeguuuCI80nHIELj22jCiVUTSJb2JH0Kr/8MP698/9FAo2bxgQXjobZGNiSwrgxtugDfegBNOgOuug8GDw7K1a5OOTkTyRYk/s8X/7LMwejTsv3+4Dfaxx5KKLFbDhoWun9mzYcwYuPJKGDgQfvpTqK5OOjoRiVu6E/8ee9S3+DdsCFc+jzoqvB8/PnT9FPHjrg4+OIz8efHFcNjXXw977hkub7z8MrgnHaGIxCHdiX/QIKishK1bw51P69dvn/gBpk5NKrq8GTsWHnwQ/vnPUPvn/vvDL4GRI+G3vy3KEa8iqZbuxL/PPiHZL1sWnmwCoYsHwm2we+wBTzyRXHx5tvfecOON4Z/jllvCNYHLLgs3hZ1wQrghTCcBkbYv3Yl/yJDwumhRSPwjR0J5eVhmFlr9U6cWRSG35ujeHb79bZg5E+bPDyOB3n03lIDYYw8YNw5+85vwmbqDRNqedCf+ffYJrwsW1Hd0Zxo/PtTymTUr/7EViOHDQ9//O++ESx4//nG4LPKjH4XzZL9+4ZrAPffA++8nHa2IZMO8DTTZKioqfPbs2bnf8NatYWRPx46hf+P++8OA9zo1NbDbbmG4y3XX5X7/bVhVVfgxNHUqTJtWX/KoXz849NAwjRkTesxKS5ONVSStzGyOu1d8bnmqEz+EMpZ//Wt4+PqyZZ/PUmPHhhPEyy/Hs/8isGVL+DXw4ov103vvhc/atQs9agccEKa60bIDBoRbJUQkPk0l/g4JBDIAuAvYHXDgFnf/fb7j2ObCC0NXz9lnN940PeUUuPrq8JST
vfbKf3xtQPv2cNBBYfr+98OypUtDD9lrr4XplVfgf/6n/judOoWLyUOHhmnIkHAz2cCB0L8/dOmSzLGIpEHeW/xm1gfo4+5zzawMmAOc7u5vNvWdWFv8O7N0aRjc/t3vhiEvEK5ozpsX6h0cfDB065ZMbG1MbS28/jq8/Xa4ZvDOO2EI6aJFsHHj9uv26hV+FQwcGF779Qu9cr17h963ukndSCJNK5gWv7tXA9XR/BozWwj0A5pM/Inq1w++9S344x/DCWDNGvjb32Dx4vB5jx5w6aUwcWJoxkqTevQIo2XrRszW2bIlXBiurIQPPqif3n8//DM/+2w4aTSma9dwAujdOzxUbZdddj517x6+17WrupsknRLt4zezQcAMYIS7N1kuLNEWP4Ss8+Uvw/PPh2GexxwDX/869O0bBrdPnhw6rv/yl1D8XnJu7dpwD8HHH9dPDd+vWhWepFk3ZfP8gc6dwwmgW7f6k0FT8126hPUzp8aWNbVcJxnJt4K7uGtm3YDngF+4+4ONfH4BcAHAwIEDD36v7mphUrZuDX0U5eWhiZnp8cfDwPfqarj88lD2snPnRMKUwB3WrQsngNra7U8ItbXhRLJ2LXz66efnG1u2di1s3ty6mDp0qD8pdOoUBpN17Lj9fGPvs1mnJd+pm0pK6iez1v/bS+EoqMRvZiXAo8CT7v7bna2feIs/G7W1YXD7n/8crlROmBCqofXtG2597d8/NBulzdq8Odzo3di0bl3zP9u4cftpw4adL6t739qTUFM6dNj+RNDwxNDwfTbr5Pr9jtbp0EEnr0wFk/jNzIA7gZXu/sNsvtMmEn+dqVPDKKBZs7a/rdUM9tsvDA896qgwDRyYXJzSpm3dGrqymnOyaPh+w4awjcxp48adL2vN+3ykmw4dWnYyyTzpNTwBtuSzXK3XmpNZISX+w4HngTeArdHiH7t7kzWQ21Tir7NmTbhCWV0dpsWLw8lg5szQGQ2hSNwRR4TbY4cMCeMa995bYxmlaG3ZEv/JpaUnsM2bdz5f937r1p0fa65MmRIeo9oShTSq5wWg+H+MlZWFC77777/98q1bw5NQnnsuDFeZNm37B76YhW6hffapn/beu/5VQ0elDWvfPrRr2nrbpu4X185OEM39rLH1hg3Lffy6c7cQrF4dBrO/8064gLx4cXi/aNHny2Huscf2J4O99goD3fv3D9cTNKRURCIF09XTEkWf+Hdk9ertTwSLFtW/X7r08+v37h1OAnVTv371J4Xddgt3QfXqFToORaSoFUxXjzRT9+5w4IFhauizz8JdT0uXhqppVVX18++/H4rm1NQ0vt3y8nASqLsFNnO+d2/o2TPcEbXrrmG+tFTDJUSKhBJ/W1Za2vh1hEzr1oWTQXV1/Z1OH320/fy8eeG1qdtjIQyDqDsJZJ4Q6uZ79AgnqYZTWVn9rbLt0l0FXKRQKPEXuy5d6i8S78yGDeFkUFMT6iyvWhVeG5v/4INQeGflyjCCaWfM6k8CTZ0c6ua7ddv51LmzfoGItJASv9Tr1ClcKB4woHnf27QpJP/Vq5s31daGE8jq1eH7a9ZkP9C7XbvsThCZU10Nhh19rroKkgJK/NJ6JSX1XT6tsXXr9rUSWjJ99FG4+J25rDmPzuzSJZwESkvD1KXLjud39nlj81266AQjiVLil8LRrl3o6ikry9023cOdOs09gXz2WZjWrQuvtbXhmZOZyz77LHSPtUSnTvUngR1Vdsv1MnWRCUr8UuzMQpLt1CmMZMq1LVtC4Z2GJ4TG5hsuW7u26eI+tbWNF/xZt671dQ86dcruBLGjKm/ZVpnL9jsqspNXSvwirdG+fX3d5nxwr68W11j1t1wsq6nZvopcZrGfDRta/itnR8xyd1LZ2WcNi/TsaGpYSCdzasOj1JT4RdoSs/rEk8suseZwD790dlQNrqn5XKxXW7vz9bJ5GENrtWvXshNGc7/zjW+EWl45pMQvIs1jVl82slBLjddd22l4
gtiw4fMFcbKZcvWdDRvCNaTmfOeww5T4RUR2KvPaTlsWU0kdJX4RkUIV0wXvtnt1QkREWkSJX0QkZZT4RURSRolfRCRllPhFRFJGiV9EJGWU+EVEUkaJX0QkZZT4RURSRolfRCRllPhFRFJGiV9EJGWU+EVEUkaJX0QkZRJJ/GY23szeNrNFZnZlEjGIiKRV3hO/mbUH/gicBOwPTDCz/fMdh4hIWiXR4v8CsMjdl7j7RuDvwGkJxCEikkpJPIGrH/BBxvsq4IsNVzKzC4ALorefmtnbLdxfL2BFC7/bVumY00HHnA6tOeY9G1tYsI9edPdbgFtaux0zm+3uFTkIqc3QMaeDjjkd4jjmJLp6lgIDMt73j5aJiEgeJJH4XwGGmNlgM+sIfA14JIE4RERSKe9dPe6+2cy+DzwJtAdud/cFMe6y1d1FbZCOOR10zOmQ82M2d8/1NkVEpIDpzl0RkZRR4hcRSZmiTvzFWhrCzG43s4/NbH7Gsl3NbJqZ/TN67RktNzO7Mfo3eN3MDkou8pYxswFm9oyZvWlmC8zs4mh50R4zgJl1NrNZZvZadNzXRcsHm9nL0fFNigZJYGadoveLos8HJXoALWRm7c3sVTN7NHpf1McLYGaVZvaGmc0zs9nRstj+vos28Rd5aYi/AOMbLLsSeNrdhwBPR+8hHP+QaLoAuDlPMebSZuAyd98fGAN8L/pvWczHDLABONbdDwBGA+PNbAxwA/Bf7r4PsAr4ZrT+N4FV0fL/itZriy4GFma8L/bjrXOMu4/OGLMf39+3uxflBIwFnsx4fxVwVdJx5fD4BgHzM96/DfSJ5vsAb0fzfwImNLZeW52Ah4HjU3bMpcBcwl3uK4AO0fJtf+eEkXJjo/kO0XqWdOzNPM7+UZI7FngUsGI+3ozjrgR6NVgW29930bb4abw0RL+EYsmH3d29Opr/ENg9mi+qf4fo5/yBwMuk4Jijbo95wMfANGAx8Im7b45WyTy2bccdfV4LlOc14Nb7HXAFsDV6X05xH28dB6aa2ZyoXA3E+PddsCUbpOXc3c2s6Mbpmlk34AHgh+6+2sy2fVasx+zuW4DRZrYLMBnYN9mI4mNmXwI+dvc5ZnZ0wuHk2+HuvtTMdgOmmdlbmR/m+u+7mFv8aSsN8ZGZ9QGIXj+OlhfFv4OZlRCS/t3u/mC0uKiPOZO7fwI8Q+jq2MXM6hptmce27bijz3sANfmNtFUOA041s0pC1d5jgd9TvMe7jbsvjV4/Jpzgv0CMf9/FnPjTVhriEeDcaP5cQj943fJzopEAY4DajJ+PbYKFpv1twEJ3/23GR0V7zABm1jtq6WNmXQjXNRYSTgD/Eq3W8Ljr/j3+BZjuUSdwW+DuV7l7f3cfRPj/dbq7n0WRHm8dM+tqZmV188AJwHzi/PtO+qJGzBdMTgbeIfSL/iTpeHJ4XPcC1cAmQv/eNwl9m08D/wSeAnaN1jXC6KbFwBtARdLxt+B4Dyf0gb4OzIumk4v5mKPjGAW8Gh33fODqaPlewCxgEXAf0Cla3jl6vyj6fK+kj6EVx3408Ggajjc6vteiaUFdrorz71slG0REUqaYu3pERKQRSvwiIimjxC8ikjJK/CIiKaPELyKSMkr8IoCZbYkqI9ZNOavmamaDLKOSqkjSVLJBJFjn7qOTDkIkH9TiF9mBqE76L6Na6bPMbJ9o+SAzmx7VQ3/azAZGy3c3s8lRDf3XzOzQaFPtzezWqK7+1OhOXJFEKPGLBF0adPX8a8Znte4+EriJUD0S4A/Ane4+CrgbuDFafiPwnIca+gcR7sSEUDv9j+4+HPgEODPWoxHZAd25KwKY2afu3q2R5ZWEh6EsiQrFfeju5Wa2glADfVO0vNrde5nZcqC/u2/I2MYgYJqHB2pgZhOBEnf/eR4OTeRz1OIX2TlvYr45NmTMb0HX1yRBSvwiO/evGa8zo/kXCRUkAc4Cno/mnwb+DbY9RKVHvoIUyZZaHSJBl+hJV3We
cPe6IZ09zex1Qqt9QrTsIuAOM7scWA6cHy2/GLjFzL5JaNn/G6GSqkjBUB+/yA5EffwV7r4i6VhEckVdPSIiKaMWv4hIyqjFLyKSMkr8IiIpo8QvIpIySvwiIimjxC8ikjL/CxvDUwPaYYuaAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import *\n",
    "\n",
    "# 定义模型和损失函数\n",
    "class Model(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden1 = nn.Linear(1, 32)\n",
    "        self.hidden2 = nn.Linear(32, 32)\n",
    "        self.output = nn.Linear(32, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(self.hidden1(x))\n",
    "        x = torch.relu(self.hidden2(x))\n",
    "        return self.output(x)\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "# 生成随机数据\n",
    "np.random.seed(0)\n",
    "n_samples = 200\n",
    "x = np.linspace(-5, 5, n_samples)\n",
    "y = 0.3 * (x ** 2) + np.random.randn(n_samples)\n",
    "\n",
    "# 转换为Tensor\n",
    "x = torch.unsqueeze(torch.from_numpy(x).float(), 1)\n",
    "y = torch.unsqueeze(torch.from_numpy(y).float(), 1)\n",
    "\n",
    "names = [\"momentum\", \"sgd\"] # 一个使用动量法，一个不使用\n",
    "losses = [[], []]\n",
    "\n",
    "# 超参数\n",
    "learning_rate = 0.0005\n",
    "n_epochs = 500\n",
    "\n",
    "# 分别训练\n",
    "for i in range(len(names)):\n",
    "    model = Model()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9 if i == 0 else 0) # 一个使用动量法，一个不使用\n",
    "    for epoch in range(n_epochs):\n",
    "        optimizer.zero_grad()\n",
    "        out = model(x)\n",
    "        loss = loss_fn(out, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        losses[i].append(loss.item())\n",
    "\n",
    "# 绘制损失值的变化趋势\n",
    "plt.figure()\n",
    "plt.plot(losses[0], 'r-', label='momentum')\n",
    "plt.plot(losses[1], 'b-', label='sgd')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.title('Training loss')\n",
    "plt.ylim((0, 12))\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上述代码中，我们使用了 torch.optim.SGD 类来实现动量法。我们传入了网络的参数以及学习率和动量参数，然后在训练循环中调用 optimizer.step() 来更新参数。通常来说，使用动量法可以帮助优化算法更快地收敛到最优解。它通过给梯度添加一个“动量”来帮助优化算法更快地摆脱局部最优解，并最终收敛到全局最优解。但是要注意，动量法仅仅是在某些情况下可以比其他优化算法更快地收敛，可以让你的训练更快地完成。换句话说，仅仅使用动量法并不能保证优于其他算法。动量法通过维护上一次更新的加速度来减少梯度下降时的震荡，但这同样需要适当的学习率。如果学习率过小，那么即使使用动量法也会出现震荡，因为学习率过小会导致更新过慢。上面的例子就可能出现这种情况。你可以尝试调整不同的参数改善不同方法的效果。\n",
    "\n",
    "**梗直哥提示：上面例子暴露的问题也显示了对深度学习算法而言，仅仅知道原理的皮毛是不够的，还要深刻领会其中的思想，掌握调参的经验，才能无往不胜。再好的模型，也得看会不会用，用的是不是地方。这就如同武器一样，光靠它不行，还要看手拿武器的人能不能把它用好。如果你对这方面的故事和经验体会感兴趣，希望进阶学习，欢迎来到我的课堂（微信：gengzhige99）。**"
   ]
  },
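  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One detail is worth keeping in mind when reading the comparison. According to the PyTorch documentation, `torch.optim.SGD` with the default `dampening=0` accumulates the velocity as `b = momentum * b + grad`, without the $(1-\\beta)$ factor used in Section 6.6.1. The two forms describe the same trajectory once the learning rate is rescaled by $(1-\\beta)$. The plain-Python sketch below checks this on a toy objective; the function names `ema_form` and `accumulate_form` and the objective $J(\\theta) = \\theta^2$ are assumptions made for illustration, and no PyTorch code is involved.\n",
    "\n",
    "```python\n",
    "def ema_form(theta, alpha, beta, n_steps):\n",
    "    # Section 6.6.1 form: v = beta * v + (1 - beta) * g;  theta -= alpha * v\n",
    "    v = 0.0\n",
    "    for _ in range(n_steps):\n",
    "        g = 2.0 * theta  # gradient of the toy objective J(theta) = theta**2\n",
    "        v = beta * v + (1 - beta) * g\n",
    "        theta -= alpha * v\n",
    "    return theta\n",
    "\n",
    "def accumulate_form(theta, lr, momentum, n_steps):\n",
    "    # Accumulating form: b = momentum * b + g;  theta -= lr * b\n",
    "    b = 0.0\n",
    "    for _ in range(n_steps):\n",
    "        g = 2.0 * theta\n",
    "        b = momentum * b + g\n",
    "        theta -= lr * b\n",
    "    return theta\n",
    "\n",
    "alpha, beta = 0.05, 0.9\n",
    "theta_ema = ema_form(1.0, alpha, beta, n_steps=30)\n",
    "theta_acc = accumulate_form(1.0, lr=alpha * (1 - beta), momentum=beta, n_steps=30)\n",
    "print(theta_ema, theta_acc)  # the two results agree up to floating-point rounding\n",
    "```\n",
    "\n",
    "Read the other way round, an accumulating run with `momentum=0.9` and a given `lr` corresponds to the Section 6.6.1 rule with a learning rate about ten times larger, since $1/(1-0.9) = 10$. The two runs in the example share the same `lr`, so this difference in effective step size should be taken into account when comparing their loss curves."
   ]
  },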
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Next 6-7 AdaGrad算法](./6-7%20AdaGrad算法.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
